{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training DRKG Using TransE_L2\n",
    "This notebook shows how to train DRKG embeddings using TransE_L2\n",
    "\n",
    "Before training the model, you need to download the original DRKG source file into your local storage, e.g., ./data/drkg.tsv"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Install DGL-KE\n",
    "Before training the model, we need to install dgl and dgl-ke packages as well as other dependencies. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip3 install torch\n",
    "!pip3 install dgl==0.4.3post2 \n",
    "!pip3 install dglke"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare train/valid/test set\n",
    "Before training, we need to split the original drkg into train/valid/test set with a 9:0.5:0.5 manner."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import sys\n",
    "sys.path.insert(1, '../utils')\n",
    "from utils import download_and_extract\n",
    "download_and_extract()\n",
    "drkg_file = '../data/drkg/drkg.tsv'\n",
    "\n",
    "df = pd.read_csv(drkg_file, sep=\"\\t\", header=None)\n",
    "triples = df.values.tolist()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We get 5,869,293 triples, now we will split them into three files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_triples = len(triples)\n",
    "num_triples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Please make sure the output directory exist.\n",
    "seed = np.arange(num_triples)\n",
    "np.random.shuffle(seed)\n",
    "\n",
    "train_cnt = int(num_triples * 0.9)\n",
    "valid_cnt = int(num_triples * 0.05)\n",
    "train_set = seed[:train_cnt]\n",
    "train_set = train_set.tolist()\n",
    "valid_set = seed[train_cnt:train_cnt+valid_cnt].tolist()\n",
    "test_set = seed[train_cnt+valid_cnt:].tolist()\n",
    "\n",
    "with open(\"train/drkg_train.tsv\", 'w+') as f:\n",
    "    for idx in train_set:\n",
    "        f.writelines(\"{}\\t{}\\t{}\\n\".format(triples[idx][0], triples[idx][1], triples[idx][2]))\n",
    "        \n",
    "with open(\"train/drkg_valid.tsv\", 'w+') as f:\n",
    "    for idx in valid_set:\n",
    "        f.writelines(\"{}\\t{}\\t{}\\n\".format(triples[idx][0], triples[idx][1], triples[idx][2]))\n",
    "\n",
    "with open(\"train/drkg_test.tsv\", 'w+') as f:\n",
    "    for idx in test_set:\n",
    "        f.writelines(\"{}\\t{}\\t{}\\n\".format(triples[idx][0], triples[idx][1], triples[idx][2]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training TransE_l2 model\n",
    "We can training the TransE_l2 model by simplying using DGL-KE command line. For more information about using DGL-KE please refer to https://github.com/awslabs/dgl-ke.\n",
    "\n",
    "Here we train the model using 8 GPUs on an AWS p3.16xlarge instance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!DGLBACKEND=pytorch dglke_train --dataset DRKG --data_path ./train --data_files drkg_train.tsv drkg_valid.tsv drkg_test.tsv --format 'raw_udd_hrt' --model_name TransE_l2 --batch_size 2048 \\\n",
    "--neg_sample_size 256 --hidden_dim 400 --gamma 12.0 --lr 0.1 --max_step 100000 --log_interval 1000 --batch_size_eval 16 -adv --regularization_coef 1.00E-07 --test --num_thread 1 --gpu 0 1 2 3 4 5 6 7 --num_proc 8 --neg_sample_size_eval 10000 --async_update"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get Entity and Relation Embeddings\n",
    "The resulting model, i.e., the entity and relation embeddings can be found under ./ckpts. (Please refer to the first line of the training log for the specific location.)\n",
    "\n",
    "The overall process will generate 4 important files:\n",
    "\n",
    "  - Entity embedding: ./ckpts/<model\\_name>_<dataset\\_name>_<run_\\id>/xxx\\_entity.npy\n",
    "  - Relation embedding: ./ckpts/<model\\_name>_<dataset\\_name>_<run\\_id>/xxx\\_relation.npy\n",
    "  - The entity id mapping, formated in <entity\\_name> <entity\\_id> pair: <data\\_path>/entities.tsv\n",
    "  - The relation id mapping, formated in <relation\\_name> <relation\\_id> pair: <data\\_path>/relations.tsv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!ls ./ckpts/TransE_l2_DRKG_0/\n",
    "!ls ./train/"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## A Glance of the Entity and Relation Embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "node_emb = np.load('./ckpts/TransE_l2_DRKG_0/DRKG_TransE_l2_entity.npy')\n",
    "relation_emb = np.load('./ckpts/TransE_l2_DRKG_0/DRKG_TransE_l2_relation.npy')\n",
    "\n",
    "print(node_emb.shape)\n",
    "print(relation_emb.shape)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
